[PyTorch] FSDP2 Support for TE #2245
base: main
Conversation
Greptile Overview
Greptile Summary
This PR adds FSDP2 (Fully Sharded Data Parallel 2) support to Transformer Engine, enabling distributed training with FP8 quantization. The implementation includes:
- Core Integration: Added `fsdp_pre_all_gather()` and `fsdp_post_all_gather()` hooks to `Float8Tensor` and `MXFP8Tensor` for weight sharding/gathering during forward/backward passes
- Tensor Operations: Implemented FSDP2-compatible `__torch_dispatch__` handlers for split, slice, as_strided, copy_, and new_zeros operations with proper handling of quantized data and transpose caches
- DTensor Support: Enhanced `TransformerEngineBaseModule.reset_parameters()` to handle DTensor parameters, including proper quantizer configuration with amax reduction across the device mesh
- Test Coverage: Expanded the test suite with multiple FP8 scaling recipes (delayed, current, MX_FP8) and sharding configurations
Key Implementation Details:
- Forward pass uses the rowwise data representation; the backward pass uses the columnwise (transpose) representation for optimal Tensor Core performance (see the sketch after this list)
- Transpose cache is maintained across tensor operations to avoid recomputation
- Amax reduction is configured across FSDP mesh for consistent scaling across shards
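To make the pass-aware selection concrete, here is a minimal sketch in the spirit of the hooks described above; the state attribute and `quantizer.set_usage` names are taken from this review's text and are assumptions, not TE's actual code.

```python
# Sketch only: choose which FP8 representation to all-gather for this pass.
# Attribute/helper names follow the review text and are assumptions.

def _is_backward_all_gather(fsdp_state) -> bool:
    # FSDP2 tags the backward re-all-gather with TrainingState.PRE_BACKWARD;
    # comparing by enum name avoids importing the internal, version-dependent path.
    state = getattr(fsdp_state, "_training_state", None)
    return state is not None and state.name == "PRE_BACKWARD"

def configure_quantizer_for_all_gather(fsdp_state, quantizer):
    """Gather only the representation this pass needs:
    rowwise for forward, columnwise (transpose) for backward."""
    is_backward = _is_backward_all_gather(fsdp_state)
    quantizer.set_usage(rowwise=not is_backward, columnwise=is_backward)
    return quantizer
```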
Issues from Previous Comments:
Several critical issues from earlier reviews remain unresolved in mxfp8_tensor.py, particularly around None handling in the slice.Tensor dispatch handler (line 479) where out_data[0].shape assumes rowwise_data exists.
Confidence Score: 3/5
- This PR requires fixes before merging - critical None handling issues remain from previous reviews
- Score reflects unresolved critical issues from previous review comments. The slice.Tensor handler at mxfp8_tensor.py:479 accesses `out_data[0].shape`, but `out_data[0]` can be None when `rowwise_data` is None, causing an AttributeError. Similar patterns exist in other handlers. The Float8Tensor implementation appears more robust with better None handling. Core FSDP2 integration logic is sound but needs defensive programming for edge cases.
- Pay close attention to `transformer_engine/pytorch/tensor/mxfp8_tensor.py` - the slice.Tensor, split.Tensor, and as_strided handlers need None-safety fixes for cases where rowwise_data or columnwise_data is None
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 3/5 | Added FSDP2 dispatch handlers (split, as_strided, copy_, slice, new_zeros) and pre/post all-gather hooks. Previous comments flagged None handling and AttributeError risks that remain unaddressed. |
| transformer_engine/pytorch/tensor/float8_tensor.py | 4/5 | Added FSDP2 support with transpose cache handling in split/new_zeros/as_strided ops, plus pre/post all-gather hooks for distributed training. Good transpose cache maintenance logic. |
| transformer_engine/pytorch/module/base.py | 4/5 | Enhanced reset_parameters to handle DTensor (FSDP2), preserving local tensor quantization while maintaining distributed mesh metadata. Properly handles amax reduction groups. |
| transformer_engine/pytorch/distributed.py | 3/5 | Added _get_module_fsdp_state helper with @lru_cache. Previous comment noted potential stale cache issues during training state changes. |
| tests/pytorch/distributed/run_fsdp2_model.py | 4/5 | Expanded FSDP2 test runner with multiple scaling recipes (delayed, current, MX_FP8), sharding configurations, and distributed test cases. |
Sequence Diagram
sequenceDiagram
participant User as Training Script
participant FSDP as FSDP2 Manager
participant Module as TE Module
participant FP8 as Float8Tensor/MXFP8Tensor
participant Quantizer as Quantizer
Note over User,Quantizer: Initialization Phase
User->>Module: fully_shard(te_module)
Module->>Module: register_parameter()
Note over Module: Skip param_init_meta if already exists
User->>Module: reset_parameters()
Module->>Module: Check if param is DTensor
alt Is DTensor
Module->>Module: Extract _local_tensor
Module->>Quantizer: Configure amax_reduction_group
Module->>FP8: Quantize local tensor
Module->>Module: Wrap back to DTensor
else Regular Tensor
Module->>FP8: Quantize tensor
end
Note over User,Quantizer: Forward Pass
FSDP->>FP8: fsdp_pre_all_gather(mesh, orig_size, module, mp_policy)
FP8->>Quantizer: Configure rowwise/columnwise usage
Note over FP8: Set usage based on training state
FP8->>FSDP: Return (sharded_tensors, metadata)
FSDP->>FSDP: AllGather sharded_tensors
FSDP->>FP8: fsdp_post_all_gather(all_gather_outputs, metadata, param_dtype)
FP8->>FP8: Reconstruct from gathered data
FP8->>FSDP: Return reconstructed Float8Tensor
FSDP->>Module: forward() with all-gathered weights
Note over User,Quantizer: Backward Pass
FSDP->>FP8: fsdp_pre_all_gather() [backward]
Note over FP8: Configure columnwise usage for backward
FP8->>FSDP: Return (sharded_tensors, metadata)
FSDP->>FSDP: AllGather for backward
FSDP->>FP8: fsdp_post_all_gather()
FP8->>FSDP: Return reconstructed tensor
Note over User,Quantizer: Gradient Sync
FSDP->>FSDP: ReduceScatter gradients
FSDP->>Module: Update sharded weights
11 files reviewed, no comments
Greptile Overview
Greptile Summary
This PR adds comprehensive FSDP2 (Fully Sharded Data Parallel v2) support to Transformer Engine, enabling distributed training with FP8 quantization across multiple scaling recipes.
Key Changes:
- Implemented `fsdp_pre_all_gather()` and `fsdp_post_all_gather()` hooks in `Float8Tensor` and `MXFP8Tensor` for FSDP2 integration
- Added torch dispatch handlers for FSDP2 tensor operations (`split`, `as_strided`, `slice`, `copy_`, `new_zeros`)
- Enhanced transpose cache management during tensor reshaping operations for improved performance
- Added training state-aware quantizer usage selection (rowwise for forward, columnwise for backward)
- Modified TE modules to detect pre-quantized weights and skip redundant quantizer configuration
- Expanded test coverage with multiple recipes (delayed scaling, current scaling, MX block scaling) and layer types
Architecture:
FSDP2 shards quantized tensors across ranks. During forward/backward, fsdp_pre_all_gather() extracts sharded data and metadata, FSDP2 performs all-gather, then fsdp_post_all_gather() reconstructs the full quantized tensor with proper transpose caches based on training state.
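For orientation, here is a minimal sketch of what this hook pair looks like on a tensor subclass, assuming the FSDP2 tensor-subclass extension signatures (which vary slightly across PyTorch releases) and hypothetical field/constructor names; it is not TE's implementation.

```python
import torch

class _FP8ShardSketch(torch.Tensor):  # hypothetical subclass, for illustration only
    # FSDP2 looks for these two methods on tensor subclasses. Newer PyTorch
    # versions also pass outer_size/outer_stride, module, and mp_policy,
    # hence the *args/**kwargs.
    def fsdp_pre_all_gather(self, mesh, *args, **kwargs):
        # Return the raw uint8 shard(s) to all-gather plus metadata needed
        # to rebuild the quantized tensor afterwards.
        return (self._data,), (self._scale_inv, self._fp8_dtype)

    def fsdp_post_all_gather(self, all_gather_outputs, metadata, param_dtype, *, out=None):
        (gathered_data,) = all_gather_outputs
        scale_inv, fp8_dtype = metadata
        if out is not None:
            # FSDP2 reuses a preallocated output; only usage flags would need updating.
            return None
        gathered = _rebuild_fp8_tensor(gathered_data, scale_inv, fp8_dtype, param_dtype)  # hypothetical helper
        # The second element tells FSDP2 which inner tensors back the gathered
        # parameter so they can be freed on reshard.
        return gathered, (gathered_data,)
```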
Confidence Score: 4/5
- Safe to merge with minor concerns about LRU cache behavior in distributed settings
- The implementation is well-structured with proper FSDP2 integration patterns. The main concern is the `@lru_cache` decorator on `_get_module_fsdp_state()`, which could potentially cache stale FSDP state if module state changes during training. The core tensor operations, quantizer handling, and test coverage are solid. No critical bugs identified, though the cache issue warrants monitoring in production.
- transformer_engine/pytorch/distributed.py - monitor `_get_module_fsdp_state()` LRU cache behavior during resharding operations
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/tensor/float8_tensor.py | 4/5 | Adds FSDP2 support with pre/post all-gather hooks, improves torch dispatch for view/split/new_zeros/as_strided ops with transpose cache handling |
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 4/5 | Implements FSDP2 support and torch dispatch handlers for split/as_strided/copy_/slice operations, adds pre/post all-gather hooks for MX format |
| transformer_engine/pytorch/distributed.py | 3/5 | Adds _get_module_fsdp_state helper with LRU cache for retrieving FSDP state from modules - cache may cause stale state issues |
| transformer_engine/pytorch/module/base.py | 4/5 | Improves DTensor parameter handling in weight initialization, adds FSDP group support to workspace management |
| transformer_engine/pytorch/module/linear.py | 5/5 | Skips quantizer configuration when weight is already quantized (FSDP2 pre-quantized weights), removes redundant columnwise usage update |
Sequence Diagram
sequenceDiagram
participant App as Application
participant FSDP as FSDP2
participant TEModule as TE Module
participant QTensor as Quantized Tensor<br/>(Float8/MXFP8)
participant Quantizer as Quantizer
Note over App,Quantizer: Initialization Phase
App->>TEModule: Create with FP8 init
TEModule->>Quantizer: Create quantizer (Delayed/Current/MX)
TEModule->>QTensor: Quantize weights
App->>FSDP: fully_shard(module)
FSDP->>QTensor: Shard weights
Note over App,Quantizer: Forward Pass (Training)
App->>FSDP: forward()
FSDP->>QTensor: fsdp_pre_all_gather()
QTensor->>QTensor: Set usage based on training state
QTensor->>FSDP: Return (sharded_tensors, metadata)
FSDP->>FSDP: All-gather sharded tensors
FSDP->>QTensor: fsdp_post_all_gather(outputs, metadata)
QTensor->>QTensor: Reconstruct full tensor
QTensor->>QTensor: update_usage() for transpose cache
QTensor->>FSDP: Return allgathered QTensor
FSDP->>TEModule: Forward with full weights
TEModule->>TEModule: Check if weight is QuantizedTensor
alt Weight already quantized
TEModule->>TEModule: Skip quantizer setup
TEModule->>QTensor: Use existing quantizer
else Weight not quantized
TEModule->>Quantizer: Set usage flags
TEModule->>Quantizer: Quantize weight
end
TEModule->>App: Return output
Note over App,Quantizer: Backward Pass
App->>FSDP: backward()
FSDP->>QTensor: fsdp_pre_all_gather() (PRE_BACKWARD)
QTensor->>QTensor: Set columnwise usage for dgrad
QTensor->>FSDP: Return (sharded_tensors, metadata)
FSDP->>FSDP: All-gather for backward
FSDP->>QTensor: fsdp_post_all_gather()
QTensor->>QTensor: Reconstruct with transpose data
FSDP->>TEModule: Backward computation
TEModule->>FSDP: Return gradients
FSDP->>FSDP: Reduce-scatter gradients
FSDP->>QTensor: Update sharded weights
11 files reviewed, no comments
/te-ci pytorch
Greptile Overview
Greptile Summary
Adds FSDP2 (Fully Sharded Data Parallel 2) support for FP8 and MXFP8 quantized tensors in PyTorch Transformer Engine, enabling distributed training with FP8 mixed-precision.
Key Changes:
- Implemented `fsdp_pre_all_gather` and `fsdp_post_all_gather` hooks for both `Float8Tensor` and `MXFP8Tensor` to handle the FSDP weight sharding/gathering lifecycle
- Added custom `__torch_dispatch__` handlers for FSDP-required operations: `aten.split.Tensor`, `aten.new_zeros`, `aten.as_strided`, `aten.copy_`, and `aten.slice.Tensor` (a sketch follows below)
- Enhanced transpose caching logic to properly maintain transposed views through various tensor operations
- Added training state-aware quantizer usage control (rowwise vs columnwise) based on forward/backward pass detection
Major Implementation Details:
- FSDP2 integration distinguishes between forward and backward passes using `TrainingState.PRE_BACKWARD` to selectively gather only the needed tensor representations
- For MXFP8, operations validate 128-byte alignment constraints and fall back to dequantization when the constraints aren't met
- Transpose cache maintenance across splits, views, and resharding ensures performance optimization for Hopper/L40 architectures
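As an illustration of the dispatch-handler shape only (not the actual Float8/MXFP8 code), here is a toy per-tensor-scaled handler for `aten.split.Tensor`; field names such as `_data` and `_scale_inv` and the `_from_data` constructor are placeholders.

```python
import torch

aten = torch.ops.aten

class _FP8Demo(torch.Tensor):
    """Toy per-tensor-scaled FP8 wrapper used only to illustrate the handler shape."""

    @classmethod
    def _from_data(cls, data, scale_inv, fp8_dtype):
        # Placeholder constructor: wrap raw uint8 data + scale into a subclass instance.
        obj = torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=torch.float32, device=data.device)
        obj._data, obj._scale_inv, obj._fp8_dtype = data, scale_inv, fp8_dtype
        return obj

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is aten.split.Tensor:
            tensor, split_size = args[0], args[1]
            dim = kwargs.get("dim", args[2] if len(args) > 2 else 0)
            # Split only the raw uint8 payload; every chunk shares the same
            # per-tensor scale_inv, so no requantization is needed.
            return [
                cls._from_data(chunk, tensor._scale_inv, tensor._fp8_dtype)
                for chunk in tensor._data.split(split_size, dim=dim)
            ]
        raise NotImplementedError(f"sketch only handles aten.split.Tensor, got {func}")
```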
Issues Found:
Multiple critical None-handling bugs exist in MXFP8 dispatch handlers where operations assume non-None data/scale tensors, which would cause AttributeError at runtime when certain usage flags are disabled.
Confidence Score: 3/5
- This PR has several critical runtime issues that need resolution before merging, particularly around None-handling in MXFP8 tensor operations
- Score reflects multiple logic bugs identified by previous reviewers (AttributeError, NameError, variable shadowing) that would cause runtime failures in MXFP8 operations. While the Float8Tensor changes appear more robust, MXFP8Tensor has roughly 8-10 critical None-handling issues across the split, slice, copy, and post_all_gather operations that need fixes
- `transformer_engine/pytorch/tensor/mxfp8_tensor.py` requires significant attention for None-handling fixes across all new dispatch handlers before this can safely merge
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/tensor/float8_tensor.py | 4/5 | Adds FSDP2 support with fsdp_pre_all_gather and fsdp_post_all_gather hooks, implements aten.split.Tensor, aten.new_zeros, and aten.as_strided handlers with transpose caching for FP8 tensors |
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 3/5 | Implements FSDP2 support and multiple torch dispatch handlers (split, as_strided, copy_, slice, new_zeros) for MXFP8 tensors; contains several critical None-handling issues that need resolution |
Sequence Diagram
sequenceDiagram
participant FSDP2
participant Float8Tensor/MXFP8Tensor
participant Quantizer
participant DeviceMesh
Note over FSDP2: Forward Pass (weights needed)
FSDP2->>Float8Tensor/MXFP8Tensor: fsdp_pre_all_gather(mesh, module, ...)
Float8Tensor/MXFP8Tensor->>Quantizer: check training_state & reshard_after_forward
Quantizer->>Quantizer: set_usage(rowwise=True, columnwise=False)
Float8Tensor/MXFP8Tensor->>FSDP2: return (sharded_data, metadata)
FSDP2->>DeviceMesh: all_gather(sharded_data)
DeviceMesh->>FSDP2: all_gather_outputs
FSDP2->>Float8Tensor/MXFP8Tensor: fsdp_post_all_gather(outputs, metadata, ...)
Float8Tensor/MXFP8Tensor->>Float8Tensor/MXFP8Tensor: reconstruct full tensor
Float8Tensor/MXFP8Tensor->>Float8Tensor/MXFP8Tensor: update_usage(rowwise=True)
Float8Tensor/MXFP8Tensor->>FSDP2: return gathered_tensor
Note over FSDP2: Compute forward pass
Note over FSDP2: Backward Pass (gradients computed)
FSDP2->>Float8Tensor/MXFP8Tensor: fsdp_pre_all_gather(mesh, module, ...)
Float8Tensor/MXFP8Tensor->>Quantizer: check training_state (PRE_BACKWARD)
Quantizer->>Quantizer: set_usage(rowwise=False, columnwise=True)
Float8Tensor/MXFP8Tensor->>FSDP2: return (transpose_data, metadata)
FSDP2->>DeviceMesh: all_gather(transpose_data)
DeviceMesh->>FSDP2: all_gather_outputs
FSDP2->>Float8Tensor/MXFP8Tensor: fsdp_post_all_gather(outputs, metadata, ...)
Float8Tensor/MXFP8Tensor->>Float8Tensor/MXFP8Tensor: reconstruct with transpose
Float8Tensor/MXFP8Tensor->>Float8Tensor/MXFP8Tensor: update_usage(columnwise=True)
Float8Tensor/MXFP8Tensor->>FSDP2: return gathered_tensor
2 files reviewed, 1 comment
Greptile Overview
Greptile Summary
This PR significantly enhances the FSDP2 test infrastructure for Transformer Engine by adding comprehensive support for FP8 mixed-precision training with distributed sharding.
Key Changes:
- Expanded FP8 recipe support: added `Float8CurrentScaling` and `MXFP8BlockScaling` alongside the existing `DelayedScaling`
- Introduced a flexible layer configuration system supporting 5 TE layer types (Linear, LayerNormLinear, LayerNormMLP, MultiheadAttention, TransformerLayer)
- Added a meta device initialization workflow for deferred parameter materialization after FSDP2 sharding
- Implemented a `test_fp8_fsdp2_allgather()` validation function to verify FP8 allgather correctness against a manual FP32 allgather (see the sketch below)
- Enhanced custom attribute save/restore logic to handle `QuantizedTensor` metadata correctly with FSDP2 DTensors
- Replaced the simple 3-layer network with a configurable multi-layer architecture supporting both `reshard_after_forward=True` and `False` test cases
The test file is well-structured with clear separation of concerns: model initialization, FSDP2 setup, training loop, and validation logic.
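The validation idea reads roughly like the following condensed sketch, assuming the model has already been wrapped with FSDP2's `fully_shard`, uses a 1-D mesh with Shard(0) placements, and that the public `unshard()`/`reshard()` methods are available; names and details here are illustrative rather than copied from the test file.

```python
import torch
import torch.distributed as dist

def check_fp8_allgather_matches_manual(model, mesh):
    """Compare FSDP2's FP8 all-gather against a manual high-precision all-gather."""
    # 1. Manually all-gather each local shard in FP32 (assumes a 1-D mesh, Shard(0)).
    manual = {}
    for name, param in model.named_parameters():
        local = param.to_local()  # DTensor -> local shard
        local = local.dequantize().float() if hasattr(local, "dequantize") else local.float()
        chunks = [torch.empty_like(local) for _ in range(mesh.size())]
        dist.all_gather(chunks, local, group=mesh.get_group())
        manual[name] = torch.cat(chunks, dim=0)
    # 2. Let FSDP2 run its own (FP8) all-gather via the pre/post hooks, then compare.
    model.unshard()
    for name, param in model.named_parameters():
        full = param.dequantize().float() if hasattr(param, "dequantize") else param.float()
        torch.testing.assert_close(full, manual[name])
    model.reshard()
```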
Confidence Score: 5/5
- This PR is safe to merge with high confidence - the changes are well-tested, properly structured, and add comprehensive FSDP2 support.
- Score reflects thorough implementation with proper error handling, comprehensive test coverage of multiple FP8 recipes and layer types, correct FSDP2 integration patterns (save/restore custom attrs, DTensor handling), and validation logic to verify FP8 allgather correctness. The code follows established patterns and includes clear documentation.
- No files require special attention - the test file is comprehensive and correctly implements FSDP2 FP8 support.
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/distributed/run_fsdp2_model.py | 5/5 | Comprehensive FSDP2 test script adding support for multiple FP8 recipes, flexible layer configurations, meta device initialization, and FP8 allgather validation |
Sequence Diagram
sequenceDiagram
participant Main as Main Process
participant Init as Model Init
participant FSDP as FSDP2 Sharding
participant Train as Training Loop
participant Test as FP8 Test
Main->>Main: Parse args & setup distributed
Main->>Init: Create FP8 recipe (delayed/current/mx_fp8)
alt FP8 Init Enabled
Init->>Init: fp8_model_init(recipe)
else FP8 Init Disabled
Init->>Init: nullcontext()
end
Init->>Init: init_te_model(config)
Note over Init: Create model on meta/cuda device
Init->>FSDP: save_custom_attrs(model)
Note over FSDP: Save QuantizedTensor metadata
FSDP->>FSDP: get_device_mesh(world_size, sharding_dims)
Note over FSDP: Setup FSDP or HSDP mesh
FSDP->>FSDP: shard_model_with_fsdp2(model, mesh)
Note over FSDP: Apply fully_shard to children & root
FSDP->>FSDP: restore_custom_attrs(model, custom_attrs)
Note over FSDP: Restore FP8 metadata to DTensors
alt Meta Device Init
FSDP->>FSDP: reset_parameters()
Note over FSDP: Materialize sharded params on cuda
end
FSDP->>Train: Create optimizer
loop For each iteration
Train->>Train: Generate input & target
Train->>Train: Forward with te.autocast(recipe)
Train->>Train: Compute loss
Train->>Train: Backward pass
Train->>Train: Optimizer step
end
alt FP8 Init Enabled
Train->>Test: test_fp8_fsdp2_allgather(model)
Test->>Test: Manual FP32 allgather
Test->>Test: FSDP2 FP8 allgather (unshard)
Test->>Test: Validate both match
Test->>Test: Reshard model
end
Main->>Main: Destroy process group
1 file reviewed, no comments
Greptile Overview
Greptile Summary
This PR adds FSDP2 (Fully Sharded Data Parallel v2) support to Transformer Engine by enabling DTensor parameter handling in the base module.
Key Changes:
- Modified `register_parameter()` to prevent overwriting FP8-specific metadata when FSDP2 re-registers parameters as DTensors
- Enhanced `reset_parameters()` to detect and handle DTensor parameters by operating on their local tensors
- Added device mesh integration for `Float8CurrentScalingQuantizer` to configure amax reduction groups for distributed training
- Implemented proper DTensor reconstruction after meta-device materialization
- Ensured quantized local tensors are correctly wrapped back into DTensor parameters
Integration Points:
- DTensor detection via an `isinstance(param, DTensor)` check
- Local tensor extraction and manipulation via `param._local_tensor` (see the sketch after this list)
- Device mesh group configuration for FP8 scaling synchronization across shards
- Parameter wrapping preserves both DTensor structure and FP8 quantization
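Put together, the flow sketched by these points might look like the following, assuming the public `DTensor` API and the quantizer attribute names quoted elsewhere in this review (`with_amax_reduction`, `amax_reduction_group`); the quantizer call itself is a hypothetical stand-in for TE's actual quantization path.

```python
import torch
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older releases

def reset_dtensor_param(module, name, param, init_fn, quantizer=None):
    """Sketch: initialize (and optionally FP8-quantize) the local shard of a DTensor param."""
    if not isinstance(param, DTensor):
        return  # regular-tensor path omitted from the sketch
    local = param._local_tensor
    if local.device.type == "meta":
        # Deferred init: materialize the shard on the real device first.
        local = torch.empty_like(local, device="cuda")
    init_fn(local)
    if quantizer is not None:
        # Reduce amax over the FSDP mesh so every shard ends up with the same scale.
        quantizer.with_amax_reduction = True
        quantizer.amax_reduction_group = param.device_mesh.get_group()
        local = quantizer(local)  # hypothetical call; TE's real API differs in detail
    # Wrap the (possibly quantized) shard back into a DTensor parameter.
    new_param = torch.nn.Parameter(
        DTensor.from_local(local, param.device_mesh, param.placements, run_check=False)
    )
    setattr(module, name, new_param)
```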
Confidence Score: 4/5
- This PR is safe to merge with minor considerations for edge cases in DTensor handling
- The implementation correctly handles DTensor parameter registration and initialization. The logic properly distinguishes between DTensor and regular tensors, extracts local tensors for processing, and reconstructs DTensors with appropriate device mesh configuration. The amax reduction group setup for Float8CurrentScalingQuantizer is correctly conditioned on both DTensor type and quantizer type. However, the score is 4 instead of 5 because: (1) the high-precision init value methods are attached to local tensors, which relies on DTensor's attribute delegation pattern, and (2) there's no explicit validation that `dtensor_param` maintains valid device_mesh/placements attributes throughout the flow, though the logic appears sound
- No files require special attention beyond standard FSDP2 testing with FP8 quantization enabled
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/module/base.py | 4/5 | Added FSDP2 DTensor support in parameter registration and reset, including proper handling of local tensors, device mesh configuration for FP8 quantization, and parameter wrapping |
Sequence Diagram
sequenceDiagram
participant FSDP2 as FSDP2
participant Module as TransformerEngineBaseModule
participant ResetParams as reset_parameters()
participant Quantizer as Float8CurrentScalingQuantizer
participant DTensor as DTensor
FSDP2->>Module: register_parameter(name, DTensor)
Note over Module: Check if param_init_meta exists<br/>Only initialize once to preserve FP8 kwargs
Module->>Module: Store param_init_meta[name]
FSDP2->>ResetParams: Trigger parameter initialization
ResetParams->>ResetParams: Check if param is DTensor
ResetParams->>DTensor: Extract _local_tensor
alt Parameter on meta device
ResetParams->>ResetParams: Create empty_like on cuda
ResetParams->>DTensor: Reconstruct DTensor.from_local()<br/>with device_mesh & placements
end
ResetParams->>ResetParams: Apply init_fn to local tensor
alt FP8 quantization enabled
ResetParams->>Quantizer: Configure quantizer settings
alt Is DTensor && Float8CurrentScaling
ResetParams->>DTensor: Get device_mesh
ResetParams->>Quantizer: Set amax_reduction_group<br/>from device_mesh.get_group()
ResetParams->>Quantizer: Enable with_amax_reduction
end
ResetParams->>Quantizer: Quantize local tensor
Quantizer-->>ResetParams: Return QuantizedTensor
end
ResetParams->>DTensor: Update _local_tensor with quantized tensor
ResetParams->>DTensor: Wrap as nn.Parameter
ResetParams->>Module: setattr(name, DTensor parameter)
1 file reviewed, no comments
Greptile Overview
Greptile Summary
Enables FSDP2 training with FP8/MXFP8 initialized weights by implementing custom allgather hooks (fsdp_pre_all_gather and fsdp_post_all_gather) that serialize FP8 tensors to uint8 for distributed communication and reconstruct them post-allgather.
Key Changes:
- FP8 Allgather Support: Float8Tensor and MXFP8Tensor now implement FSDP2 hooks that return uint8 data with metadata (scale_inv, dtype, quantizer) for allgather, enabling FP8 communication instead of high-precision
- Selective Usage Based on Training State: Pre-allgather hooks optimize memory by gathering only rowwise data for the forward pass and columnwise data for the backward pass when `reshard_after_forward=True`
- DTensor Integration: `TransformerEngineBaseModule.reset_parameters()` now handles FSDP2's DTensor parameters by operating on `_local_tensor` and preserving FP8 metadata across parameter re-registration
- Transpose Cache Management: Enhanced `__torch_dispatch__` handlers for split/view/new_zeros/as_strided ops to maintain transpose caches for both data and data_transpose, improving performance
- Amax Reduction Setup: Quantizers are configured with appropriate reduction groups for synchronized scale updates across FSDP shards
Issues Found:
- Potential tensor unpacking bug in `mxfp8_tensor.py:613-617` where both `[:2]` and `[-2:]` slicing could select duplicate tensors if validation fails
Confidence Score: 4/5
- This PR is largely safe to merge with one logical issue that needs verification in edge cases
- The implementation is well-structured and addresses a significant feature gap (FSDP2 support for FP8 weights). The core allgather hook logic is sound and properly handles the forward/backward pass distinction. However, there's a potential edge-case bug in MXFP8Tensor's `fsdp_post_all_gather` where tensor unpacking could fail if the tuple length doesn't match the usage flags, though this is unlikely in normal operation since the pre/post hooks are paired
- transformer_engine/pytorch/tensor/mxfp8_tensor.py - verify the tensor unpacking logic at lines 613-617 handles all edge cases correctly
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/tensor/float8_tensor.py | 4/5 | Adds FSDP2 support with fsdp_pre_all_gather and fsdp_post_all_gather hooks, implements FP8 allgather by returning uint8 data with metadata for reconstruction, enhances __torch_dispatch__ to handle transpose caching for split/view/new_zeros/as_strided ops |
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 3/5 | Adds FSDP2 allgather hooks for MXFP8 tensors with selective rowwise/columnwise data gathering based on training state, implements torch dispatch handlers for split/as_strided/copy_/slice/new_zeros ops with MXFP8 block scaling constraints, has potential tensor unpacking issue in fsdp_post_all_gather |
| transformer_engine/pytorch/distributed.py | 4/5 | Adds _get_module_fsdp_state helper with LRU caching to retrieve FSDP state from modules or their closest FSDP parent |
| transformer_engine/pytorch/module/base.py | 4/5 | Updates reset_parameters to handle DTensor (FSDP2) by operating on _local_tensor, preserves FP8 metadata during FSDP2's re-registration of parameters as DTensors, sets up amax reduction groups for DTensor quantizers |
Sequence Diagram
sequenceDiagram
participant FSDP2 as FSDP2
participant Float8Tensor as Float8Tensor/MXFP8Tensor
participant Quantizer as Quantizer
participant Module as TransformerEngineModule
Note over FSDP2,Module: Forward Pass (or Backward if reshard_after_forward=True)
FSDP2->>Float8Tensor: fsdp_pre_all_gather(mesh, orig_size, module, ...)
Float8Tensor->>Module: _get_module_fsdp_state(module)
Module-->>Float8Tensor: fsdp_state
Float8Tensor->>Quantizer: copy()
Quantizer-->>Float8Tensor: quantizer_copy
alt reshard_after_forward=True
Float8Tensor->>Float8Tensor: Determine forward vs backward from training_state
Float8Tensor->>Quantizer: set_usage(rowwise=!is_backward, columnwise=is_backward)
Note over Float8Tensor: Pack only needed data based on pass direction
else reshard_after_forward=False
Note over Float8Tensor: Pack both rowwise and columnwise if needed
end
Float8Tensor-->>FSDP2: (sharded_uint8_tensors, metadata)
FSDP2->>FSDP2: AllGather uint8 tensors across ranks
FSDP2->>Float8Tensor: fsdp_post_all_gather(all_gather_outputs, metadata, param_dtype, out)
Float8Tensor->>Float8Tensor: Unpack all_gather_outputs and metadata
alt out exists
Float8Tensor->>Float8Tensor: update_usage() on existing tensor
else out is None
Float8Tensor->>Float8Tensor: Construct new Float8Tensor/MXFP8Tensor
Float8Tensor->>Float8Tensor: update_usage() on new tensor
end
Float8Tensor-->>FSDP2: (reconstructed_fp8_tensor, all_gather_outputs)
Note over FSDP2,Module: Tensor ready for forward/backward computation
1 file reviewed, 1 comment
/te-ci pytorch
Greptile Overview
Greptile Summary
This PR makes focused changes to quantized_tensor.py to improve FSDP2 compatibility:
- Recursive list handling: Added support for recursively updating lists of tensors in in-place operations (lines 436-439). This handles operations like `split` that return multiple tensors, ensuring QuantizedTensors within lists are properly updated (see the sketch below).
- Simplified `make_like` method: Removed the `data` parameter from the base class implementation (lines 493-506). The method now focuses solely on creating views of tensors. This change is safe because:
  - Subclasses like `Float8Tensor` override this method and still support the `data` parameter for backward compatibility
  - The base class docstring now correctly reflects that the method is "intended to create view of tensors"
  - Existing usages with the `data=` parameter are handled by the overridden methods in subclasses
These are minimal, well-scoped changes that support the broader FSDP2 integration without breaking existing functionality.
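The recursive-list idea can be pictured with a small helper; `QuantizedTensorStub` and `quantize_` are stand-ins for whatever the base class actually provides, so treat this as a sketch of the control flow, not the real method.

```python
import torch

class QuantizedTensorStub:
    """Minimal stand-in for TE's QuantizedTensor, just enough for this sketch."""
    def quantize_(self, src: torch.Tensor) -> None:
        ...  # re-quantize `src` into this tensor's FP8 storage (placeholder)

def update_from_result(arg, result):
    """Recursively push an in-place op's results back into quantized tensors.

    Optimizers may pass lists of parameters (e.g. foreach-style lerp_/copy_),
    so lists and tuples are walked element-by-element instead of being skipped.
    """
    if isinstance(arg, (list, tuple)):
        for a, r in zip(arg, result):
            update_from_result(a, r)
    elif isinstance(arg, QuantizedTensorStub):
        arg.quantize_(result)
    # plain torch.Tensors are already updated in place by the op itself
```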
Confidence Score: 4/5
- This PR is safe to merge with minimal risk
- The changes are minimal and focused, with only two small modifications to `quantized_tensor.py`. The recursive list handling is a straightforward addition that improves robustness. The `make_like` signature change is safe because subclasses override the method and maintain backward compatibility. No issues found that would impact correctness or introduce bugs.
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/quantized_tensor.py | 4/5 | Added recursive list handling for in-place operations and simplified make_like method by removing data parameter. Changes are minimal and focused on improving FSDP2 compatibility. |
Sequence Diagram
sequenceDiagram
participant FSDP as FSDP2 Framework
participant PreHook as fsdp_pre_allgather
participant QT as QuantizedTensor
participant PostHook as fsdp_post_allgather
Note over FSDP: Forward/Backward Pass Begins
FSDP->>PreHook: Call pre_allgather hook
PreHook->>QT: Extract uint8 data + metadata
Note over QT: For FP8: extract _data tensor<br/>For MXFP8: extract rowwise/columnwise data
QT-->>PreHook: Return (uint8_tensors, metadata)
PreHook-->>FSDP: Return allgather input
Note over FSDP: Perform AllGather on uint8 data
FSDP->>PostHook: Call post_allgather hook
PostHook->>QT: Reconstruct from allgathered data
Note over QT: Rebuild Float8/MXFP8 tensor<br/>from uint8 + metadata
QT-->>PostHook: Return reconstructed tensor
PostHook-->>FSDP: Return full tensor
Note over FSDP: Continue computation with full tensor
1 file reviewed, no comments
/te-ci pytorch
Greptile Overview
Greptile Summary
Enables FSDP2 distributed training with FP8-initialized weights by implementing allgather hooks, torch dispatch operations, and DTensor support for TE quantized tensors.
Key Changes:
- Implements `fsdp_pre_all_gather` and `fsdp_post_all_gather` hooks for Float8Tensor and MXFP8Tensor to enable 8-bit weight allgather (instead of high-precision)
- Adds torch dispatch support for FSDP2 tensor operations: `split`, `copy_`, `slice`, `view`, `as_strided`, `new_zeros`
- Updates `reset_parameters` in TransformerEngineBaseModule to handle DTensor for deferred initialization (meta device)
- Fixes optimizer weight updates by recursively handling lists of quantized tensors in in-place operations
- Moves quantizer usage validation from the forward to the backward pass to support phase-aware allgather
- Configures amax reduction groups for the current scaling quantizer to synchronize scale inverses across FSDP shards (see the sketch after this list)
- Comprehensive test coverage for multiple TE layers with delayed/current/MX_FP8 scaling recipes
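The scale-inverse synchronization can be pictured as below, reusing the quantizer attribute names that appear in this review (`with_amax_reduction`, `amax_reduction_group`) and assuming a 1-D FSDP mesh; this is an assumption-laden sketch rather than TE's code.

```python
from torch.distributed.device_mesh import DeviceMesh

def sync_current_scaling_across_shards(quantizer, mesh: DeviceMesh):
    """Make every FSDP shard compute the same scale_inv after a weight update.

    With current scaling, each rank only sees its own shard of the weight, so
    the locally computed amax differs per rank; reducing amax over the FSDP
    process group keeps the quantization scale identical across shards.
    """
    quantizer.with_amax_reduction = True
    quantizer.amax_reduction_group = mesh.get_group()
    return quantizer
```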
Memory Impact:
- FP8 per-tensor scaling reduces memory footprint by ~50% vs BF16 on Blackwell (as expected)
- MXFP8 block scaling maintains similar memory to BF16 due to rowwise+columnwise storage requirements
Confidence Score: 4/5
- This PR is mostly safe to merge with one critical bug fix needed in MXFP8 shape handling
- Score reflects solid implementation with comprehensive test coverage, but deducted 1 point due to critical bug in mxfp8_tensor.py:658 where both rowwise/columnwise data can be None causing AttributeError, and minor concerns about LRU cache causing potential memory leaks
- transformer_engine/pytorch/tensor/mxfp8_tensor.py:658 requires fix for None handling
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/tensor/float8_tensor.py | 4/5 | Adds FSDP2 allgather hooks and torch dispatch ops (view, split, copy, slice, as_strided, new_zeros) to support 8-bit weight sharding. Implements current scaling quantizer sync for FSDP weight updates. Potential issue with shape handling in line 658. |
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 3/5 | Implements FSDP2 support with rowwise/columnwise data handling for block-scaled FP8. Adds dispatch ops (split, copy, slice, as_strided, new_zeros). Critical bug at line 658 where both data tensors can be None causing AttributeError. |
| transformer_engine/pytorch/module/base.py | 4/5 | Updates reset_parameters to handle DTensor (FSDP2 deferred init) by quantizing local tensor and reconstructing DTensor. Adds amax_reduction_group configuration for current scaling. Guard prevents metadata loss during DTensor conversion. |
| transformer_engine/pytorch/distributed.py | 4/5 | Adds _get_module_fsdp_state helper with LRU cache to find FSDP state for allgather hooks. Cache could cause memory leaks but likely acceptable given module stability during training. |
| transformer_engine/pytorch/quantized_tensor.py | 5/5 | Fixes in-place ops to recursively handle lists of tensors (optimizer sends batched updates). Removes data parameter from make_like to avoid confusion between view creation and data initialization. |
Sequence Diagram
sequenceDiagram
participant FSDP as FSDP2
participant QT as QuantizedTensor (FP8/MXFP8)
participant Helper as _get_module_fsdp_state
participant Optimizer as Optimizer
Note over FSDP,QT: Forward Pass - Weight Allgather
FSDP->>QT: fsdp_pre_all_gather(module, mesh, ...)
QT->>Helper: Get FSDP state to determine phase
Helper-->>QT: training_state, reshard_after_forward
QT->>QT: Set quantizer.rowwise_usage=True, columnwise=False
QT-->>FSDP: (uint8_data, ...), metadata
FSDP->>FSDP: All-gather uint8 shards
FSDP->>QT: fsdp_post_all_gather(gathered_outputs, metadata)
QT->>QT: Reconstruct FP8 tensor with rowwise usage
QT-->>FSDP: Allgathered FP8 weight
Note over FSDP,QT: Forward Pass Compute
FSDP->>QT: Forward computation with FP8 weights
Note over FSDP,QT: Backward Pass - Weight Allgather (if reshard_after_forward)
FSDP->>QT: fsdp_pre_all_gather(module, mesh, ...)
QT->>Helper: Get FSDP state
Helper-->>QT: training_state=PRE_BACKWARD
QT->>QT: Set quantizer.rowwise=False, columnwise_usage=True
QT-->>FSDP: (uint8_data_transpose, ...), metadata
FSDP->>FSDP: All-gather transpose/columnwise shards
FSDP->>QT: fsdp_post_all_gather(gathered_outputs, metadata)
QT->>QT: Reconstruct FP8 tensor with columnwise usage
QT-->>FSDP: Allgathered FP8 weight
Note over FSDP,Optimizer: Gradient Computation & Weight Update
FSDP->>FSDP: Compute gradients, reduce-scatter
Optimizer->>QT: In-place update (lerp on list of tensors)
QT->>QT: Dequantize, apply op, quantize with amax reduction
Note over QT: Amax synchronized across shards<br/>for current scaling
11 files reviewed, 2 comments
return inp, handle

@lru_cache
style: @lru_cache on instance methods can cause memory leaks since module references stay cached. The FSDP state itself is stateful and mutated during training, so caching based on module identity could potentially return stale references if modules are recreated. Consider @lru_cache(maxsize=128) with explicit cache invalidation or verify modules are never recreated during training.
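One way to act on this suggestion, sketched with a hypothetical traversal helper (`_find_fsdp_state`) since the real lookup logic lives in `_get_module_fsdp_state`:

```python
from functools import lru_cache

@lru_cache(maxsize=128)  # bound the cache so module references cannot pile up without limit
def get_module_fsdp_state_cached(module):
    # The returned FSDP state object is mutable, so callers still observe live
    # updates (training_state, reshard flags) even though the lookup is cached.
    return _find_fsdp_state(module)  # hypothetical helper that walks the module tree

def invalidate_fsdp_state_cache():
    # Call if modules are re-wrapped or recreated mid-training, so the cache
    # cannot return state for a stale module instance.
    get_module_fsdp_state_cached.cache_clear()
```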
columnwise_scale_inv=columnwise_scale_inv,
fp8_dtype=fp8_dtype,
dtype=param_dtype,
shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape,
logic: if both rowwise_data and columnwise_data are None (when both usage flags are False), accessing .shape raises AttributeError
- shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape,
+ shape=rowwise_data.shape if rowwise_data is not None else (columnwise_data.shape if columnwise_data is not None else torch.Size([0])),
/te-ci L1 pytorch
Greptile Overview
Greptile Summary
This PR enables FSDP2 training with FP8-initialized weights by implementing custom allgather hooks and torch dispatch handlers. The implementation addresses three key issues: memory footprint with FP8 weights, correct weight updates during training, and efficient 8-bit weight allgather.
Key Changes:
- Implements `fsdp_pre_all_gather` and `fsdp_post_all_gather` hooks for Float8Tensor and MXFP8Tensor to handle FP8/MXFP8 allgather using uint8 data
- Adds torch dispatch handlers for `split`, `slice`, `copy`, `new_zeros`, `as_strided`, and `view` operations needed by FSDP2
- Uses FSDP state to detect forward vs backward pass and set the appropriate rowwise/columnwise quantizer usage
- Sets `amax_reduction_group` for current scaling quantization to synchronize scale inverses across shards
- Updates DTensor parameters correctly during deferred initialization (meta device)
- Moves quantizer usage validation from the layer `forward()` to the `_apply_forward`/backward functions to accommodate FSDP2's separate allgather for forward/backward
Critical Issues Found:
`mxfp8_tensor.py:502` and `mxfp8_tensor.py:389` have a potential `AttributeError` when accessing `.shape` on tensors that can be `None` (when neither rowwise nor columnwise data exists)
Confidence Score: 3/5
- This PR introduces critical bugs that will cause runtime failures in edge cases, but the core FSDP2 integration logic is sound
- Score of 3 reflects two critical logic errors in the MXFP8Tensor dispatch handlers (lines 389 and 502) that will cause an AttributeError when accessing `.shape` on None values. These bugs occur when the quantizer has neither rowwise nor columnwise usage enabled, which may be rare but is not prevented. The rest of the implementation is well-designed with proper handling of the forward/backward distinction, amax reduction groups, and DTensor support
- transformer_engine/pytorch/tensor/mxfp8_tensor.py lines 389 and 502 require immediate fixes to handle None tensor data
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 2/5 | Adds FSDP2 torch dispatch handlers for split, slice, copy, new_zeros, as_strided, and view ops. Contains critical bug: line 502 accesses .shape on out_data[0] which can be None when neither rowwise nor columnwise data exists |
| transformer_engine/pytorch/tensor/float8_tensor.py | 3/5 | Adds FSDP2 allgather hooks and torch dispatch handlers for various ops. Implements rowwise/columnwise usage tracking for forward/backward passes. Generally well-structured but relies on cached FSDP state lookup |
| transformer_engine/pytorch/distributed.py | 4/5 | Adds _get_module_fsdp_state helper with @lru_cache to find FSDP state for modules. Cache is appropriate since it stores reference to mutable state object |
| transformer_engine/pytorch/module/base.py | 4/5 | Updates reset_parameters to handle DTensor params for FSDP2 deferred init, sets amax reduction group for current scaling quantization. Logic is sound |
Sequence Diagram
sequenceDiagram
participant FSDP2
participant Float8Tensor
participant MXFP8Tensor
participant Quantizer
participant TE_Module
Note over FSDP2,TE_Module: Forward Pass
FSDP2->>Float8Tensor: fsdp_pre_all_gather(mesh, module, ...)
Float8Tensor->>Float8Tensor: Set amax_reduction_group for current scaling
Float8Tensor->>Float8Tensor: Get FSDP state, check reshard_after_forward
Float8Tensor->>Quantizer: copy() and set_usage(rowwise=True)
Float8Tensor-->>FSDP2: Return (uint8 data, metadata)
FSDP2->>FSDP2: All-gather uint8 data across shards
FSDP2->>Float8Tensor: fsdp_post_all_gather(gathered_data, metadata)
Float8Tensor->>Float8Tensor: Reconstruct with rowwise usage
Float8Tensor-->>FSDP2: Return reconstructed Float8Tensor
FSDP2->>TE_Module: forward(input)
TE_Module->>TE_Module: Validate rowwise usage in _apply_forward
TE_Module-->>FSDP2: output
Note over FSDP2,TE_Module: Backward Pass (if reshard_after_forward)
FSDP2->>Float8Tensor: fsdp_pre_all_gather(mesh, module, ...)
Float8Tensor->>Float8Tensor: Get FSDP state, detect PRE_BACKWARD
Float8Tensor->>Quantizer: copy() and set_usage(columnwise=True)
Float8Tensor-->>FSDP2: Return (uint8 data, metadata)
FSDP2->>FSDP2: All-gather uint8 data across shards
FSDP2->>Float8Tensor: fsdp_post_all_gather(gathered_data, metadata)
Float8Tensor->>Float8Tensor: Reconstruct with columnwise usage
Float8Tensor-->>FSDP2: Return reconstructed Float8Tensor
TE_Module->>TE_Module: Validate columnwise usage in _apply_backward
Note over FSDP2,TE_Module: Weight Initialization (Meta Device)
TE_Module->>TE_Module: reset_parameters() with DTensor
TE_Module->>Quantizer: Set amax_reduction_group from mesh
TE_Module->>Float8Tensor: Initialize with FP8 quantization
TE_Module->>TE_Module: Update DTensor.local_tensor
11 files reviewed, 2 comments
)
out_data.append(scale_inv_out)
return MXFP8Tensor(
    shape=out_data[0].shape,
logic: AttributeError if both _rowwise_data and _columnwise_data are None - accessing .shape on out_data[0] will fail when neither usage flag is set
- shape=out_data[0].shape,
+ shape=out_data[0].shape if out_data[0] is not None else out_data[1].shape,
out_data.append(scale_inv_out)
return [
    MXFP8Tensor(
        shape=splitted_tensor_data[0].size(),
logic: AttributeError if splitted_tensor_data[0] is None - happens when tensor._rowwise_data is None before splitting
- shape=splitted_tensor_data[0].size(),
+ shape=splitted_tensor_data[0].size() if splitted_tensor_data[0] is not None else splitted_tensor_data[1].size(),
Greptile Overview
Greptile Summary
This PR enables end-to-end FSDP2 training for PyTorch models with Transformer Engine layers initialized with FP8/MXFP8 weights, solving three critical issues: memory footprint problems with FP8-initialized weights, weight update correctness during training, and enabling 8-bit allgather instead of high-precision.
Key Changes:
- FSDP2 Allgather Hooks: Implements `fsdp_pre_all_gather` and `fsdp_post_all_gather` methods for Float8Tensor and MXFP8Tensor to support 8-bit weight allgather by returning uint8 data and metadata for reconstruction
- Torch Dispatch Operations: Adds handlers for `split`, `new_zeros`, `as_strided`, `copy_`, `slice`, and `view` operations to support FSDP2 sharding and resharding of quantized tensors
- FSDP State Management: Introduces a `_get_module_fsdp_state` helper with an LRU cache to determine the forward/backward pass and the `reshard_after_forward` configuration, enabling proper rowwise/columnwise usage selection
- Current Scaling Synchronization: Sets the amax reduction group in quantizers during allgather to ensure all weight shards share the same scale inverse after optimizer updates
- DTensor Support: Updates `reset_parameters` in the base module to handle DTensor parameters for FSDP2 deferred initialization with proper quantizer configuration
- Quantized Tensor Fixes: Fixes in-place operations to handle lists of tensors (for optimizer lerp operations) and removes the incorrect `data` parameter from the `make_like` API
- Usage Validation Refactoring: Moves quantizer usage validation from layer forward to the forward/backward functions, and removes unnecessary quantizer updates when weights are already quantized
Memory Impact: FP8 per-tensor quantization reduces memory by 50% vs BF16 on Blackwell. MXFP8 has similar memory footprint to BF16 due to needing both rowwise/columnwise representations.
Test Coverage: Comprehensive tests cover delayed scaling, current scaling, and MX_FP8 block scaling recipes with various layer types (Linear, LayerNormLinear, TransformerLayer) and both FSDP/HSDP configurations.
Confidence Score: 4/5
- Safe to merge with minor considerations - addresses long-standing FSDP2+FP8 issues with comprehensive implementation
- Score of 4 reflects a solid implementation with extensive test coverage addressing critical functionality gaps. The changes are well-architected with proper separation between FP8/MXFP8 tensor handling, FSDP2 hooks, and torch dispatch operations. Previous syntax errors in mxfp8_tensor.py mentioned in earlier comments have been fixed. Main concerns are: (1) the LRU cache on `_get_module_fsdp_state` could retain module references indefinitely, though the return value is a mutable state reference, (2) the complex logic for determining the forward/backward pass and reshard_after_forward could benefit from additional inline documentation, (3) the MXFP8 view operation intentionally falls back to the dequantize path with a warning when flattening the inner dimension. The PR resolves several critical GitHub issues (#1688, #401, #1135, #1188) and includes validation tests.
- Pay close attention to `transformer_engine/pytorch/tensor/float8_tensor.py` and `transformer_engine/pytorch/tensor/mxfp8_tensor.py` for the complex torch dispatch logic and allgather hooks
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/distributed.py | 4/5 | Adds _get_module_fsdp_state helper with @lru_cache to find FSDP state for modules - enables determining forward/backward pass during allgather |
| transformer_engine/pytorch/tensor/float8_tensor.py | 4/5 | Implements FSDP2 hooks (fsdp_pre_all_gather, fsdp_post_all_gather) and torch dispatch for split/new_zeros/as_strided/copy operations to support 8-bit allgather |
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 4/5 | Implements FSDP2 hooks and torch dispatch for MXFP8 tensors with rowwise/columnwise data handling - includes split/as_strided/copy/slice operations |
| transformer_engine/pytorch/quantized_tensor.py | 4/5 | Fixes in-place ops to handle lists of tensors (for optimizer updates) and removes data parameter from make_like to fix view semantics |
| transformer_engine/pytorch/module/base.py | 4/5 | Adds DTensor support in reset_parameters for FSDP2 deferred initialization, handles amax reduction group setup for current scaling quantization |
| transformer_engine/pytorch/module/linear.py | 4/5 | Removes quantizer updates when weight is already quantized, moves columnwise usage validation from forward to backward function |
Sequence Diagram
sequenceDiagram
participant User
participant FSDP2
participant TEModule as TE Module
participant Float8Tensor
participant Quantizer
participant AllGather as FSDP AllGather
User->>FSDP2: Initialize model with fp8_model_init
FSDP2->>TEModule: Create FP8/MXFP8 weight shards
TEModule->>Float8Tensor: Initialize quantized weights
Float8Tensor->>Quantizer: Setup amax reduction group
User->>FSDP2: Start training iteration (forward pass)
FSDP2->>Float8Tensor: fsdp_pre_all_gather(module, mesh)
Float8Tensor->>TEModule: Get FSDP state via _get_module_fsdp_state
Float8Tensor->>Quantizer: Set rowwise usage for forward
Float8Tensor-->>FSDP2: Return (uint8_data,), metadata
FSDP2->>AllGather: AllGather uint8 data across shards
AllGather-->>FSDP2: Gathered uint8 data
FSDP2->>Float8Tensor: fsdp_post_all_gather(outputs, metadata)
Float8Tensor->>Float8Tensor: Reconstruct Float8Tensor with gathered data
Float8Tensor->>Float8Tensor: update_usage(rowwise=True)
Float8Tensor-->>FSDP2: Return reconstructed weight tensor
FSDP2->>TEModule: Forward pass with gathered weights
TEModule->>TEModule: Compute activations
alt reshard_after_forward=True
FSDP2->>FSDP2: Reshard weights after forward
end
User->>FSDP2: Backward pass
alt reshard_after_forward=True
FSDP2->>Float8Tensor: fsdp_pre_all_gather (backward)
Float8Tensor->>Quantizer: Set columnwise usage for backward
Float8Tensor-->>FSDP2: Return appropriate data for backward
FSDP2->>AllGather: AllGather for backward pass
FSDP2->>Float8Tensor: fsdp_post_all_gather
Float8Tensor-->>FSDP2: Reconstructed weight for backward
end
FSDP2->>TEModule: Backward pass with weights
TEModule->>TEModule: Compute gradients
FSDP2->>FSDP2: ReduceScatter gradients
User->>FSDP2: Optimizer step
FSDP2->>Float8Tensor: Update weight shards (via lerp/copy_)
Float8Tensor->>Quantizer: Sync amax across shards
Quantizer->>Quantizer: AllReduce amax for single scale_inv
11 files reviewed, no comments
Description
Motivation:
What does this PR do?
Type of change
Changes
FSDP Allgather Hooks for FP8/MXFP8: Adds fsdp_pre_all_gather and fsdp_post_all_gather methods for FP8/MXFP8 tensors, since allgather is only supported for native torch tensors with uint8/fp16/bf16/fp32 data types. For us, fsdp_pre_all_gather returns the uint8 sharded tensors for FP8/MXFP8 that need to be all-gathered, along with the metadata needed to reconstruct the FP8/MXFP8 tensor after the allgather. fsdp_post_all_gather then reconstructs the Float8/MXFP8 tensor from the all-gathered uint8 data.
FP8/MXFP8 Torch Dispatch Functions for FSDP2: handle ops on both rowwise/columnwise data (MXFP8) and data/transpose (FP8). NOTE: only MXFP8 tensors without padding requirements are handled; if padding is needed, we go down the dequantization-compute-quantization route.
Quantized Tensor Class Issues:
Validating rowwise/columnwise Usages for quantizers/tensors in TE Layers
Resetting Parameters for Deferred Initialization(meta device)
Test and Miscellaneous issues
Checklist:
Summary by CodeRabbit
Release Notes